{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matplotlib.get_backend :  module://ipykernel.pylab.backend_inline\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "GPU_id = 2\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_id)\n",
    "\n",
    "from fastai.basic_train import *\n",
    "from fastai.callbacks import SaveModelCallback\n",
    "from functools import partial\n",
    "from torch.utils.dlpack import from_dlpack\n",
    "import cudf as gd\n",
    "import warnings\n",
    "import glob \n",
    "\n",
    "\n",
    "from mpnn_model.common import * \n",
    "from mpnn_model.common_constants import * \n",
    "from mpnn_model.dataset import TensorBatchDataset, BatchDataBunch, BatchDataLoader\n",
    "from mpnn_model.data_collate import tensor_collate_rnn\n",
    "from mpnn_model.GaussRank import GaussRankMap\n",
    "from mpnn_model.helpers import load_cfg\n",
    "from mpnn_model.model import Net \n",
    "from mpnn_model.train_loss import train_criterion, lmae_criterion\n",
    "from mpnn_model.callback import get_reverse_frame, lmae, LMAE\n",
    "from mpnn_model.radam import * \n",
    "from mpnn_model.build_predictions import do_test \n",
    "from mpnn_model.helpers import * \n",
    "\n",
    "warnings.filterwarnings(\"ignore\") \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load config file "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load config dict \n",
    "cfg = load_cfg('/rapids/notebooks/srabhi/champs-2019/CherKeng_solution/fastai_code/experiments/MPNN_RNN_PREDICT_TYPE_LMAE_WO_GAUSSRANK.yaml')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "COUPLING_MAX = 136"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center> <h1> Dataset </h1> </center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = cfg['dataset']['input_path']\n",
    "fold = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 42.2 s, sys: 17.1 s, total: 59.3 s\n",
      "Wall time: 2min 19s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "validation = gd.read_parquet(DATA_DIR +'/rnn_parquet/fold_%s/validation.parquet'%fold)\n",
    "train = gd.read_parquet(DATA_DIR +'/rnn_parquet/fold_%s/train.parquet' %fold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = cfg['train']['batch_size']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Convert the train and validation dataframes to tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** Convert validation tensors **\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Column-name patterns selecting the node / edge / coupling feature columns.\n",
    "# Compiled once instead of once per column.\n",
    "NODE_COL_RE = re.compile('^node_[0-9]+')\n",
    "EDGE_COL_RE = re.compile('^edge_[0-9]+')\n",
    "COUPLING_COL_RE = re.compile('^coupling_[0-9]+')\n",
    "\n",
    "\n",
    "def frame_to_tensors(frame):\n",
    "    \"\"\"Convert a cuDF frame into the tensor list passed to TensorBatchDataset.\n",
    "\n",
    "    The `num_nodes`, `num_edge` and `num_coupling` columns are cast to int64;\n",
    "    the columns matching `node_<n>`, `edge_<n>` and `coupling_<n>` are converted\n",
    "    to float32 tensors. Data is handed to PyTorch through DLPack.\n",
    "\n",
    "    Returns:\n",
    "        [nodes, edges, coupling, num_nodes, num_edges, num_coupling]\n",
    "    \"\"\"\n",
    "    def to_tensor(selection):\n",
    "        return from_dlpack(frame[selection].to_dlpack())\n",
    "\n",
    "    def matching(pattern):\n",
    "        return [col for col in frame.columns if pattern.match(col)]\n",
    "\n",
    "    num_nodes = to_tensor('num_nodes').long()\n",
    "    num_edges = to_tensor('num_edge').long()\n",
    "    num_coupling = to_tensor('num_coupling').long()\n",
    "    nodes = to_tensor(matching(NODE_COL_RE)).type(torch.float32)\n",
    "    edges = to_tensor(matching(EDGE_COL_RE)).type(torch.float32)\n",
    "    coupling = to_tensor(matching(COUPLING_COL_RE)).type(torch.float32)\n",
    "    return [nodes, edges, coupling, num_nodes, num_edges, num_coupling]\n",
    "\n",
    "\n",
    "# convert train to tensors \n",
    "mol_train = train.molecule_name.unique().to_pandas().values\n",
    "train_dataset = TensorBatchDataset(mol_train, \n",
    "                                tensors=frame_to_tensors(train), \n",
    "                                batch_size=batch_size,\n",
    "                                collate_fn=tensor_collate_rnn,\n",
    "                                COUPLING_MAX=COUPLING_MAX,\n",
    "                                mode='train',\n",
    "                                csv='train')\n",
    "del train\n",
    "\n",
    "# convert validation to tensors \n",
    "print('** Convert validation tensors **\\n')\n",
    "# Keep the validation tensors bound to these notebook-level names: the\n",
    "# prediction section further down rebuilds valid_dataset from them.\n",
    "(nodes_matrix, edges_matrix, coupling_matrix,\n",
    " num_nodes_tensor, num_edges_tensor, num_coupling_tensor) = frame_to_tensors(validation)\n",
    "mol_valid = validation.molecule_name.unique().to_pandas().values\n",
    "valid_dataset = TensorBatchDataset(mol_valid, \n",
    "                                tensors=[nodes_matrix, edges_matrix, coupling_matrix,\n",
    "                                         num_nodes_tensor, num_edges_tensor, num_coupling_tensor], \n",
    "                                batch_size=batch_size,\n",
    "                                collate_fn=tensor_collate_rnn,\n",
    "                                COUPLING_MAX=COUPLING_MAX,\n",
    "                                mode='train',\n",
    "                                csv='train')\n",
    "del validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = BatchDataBunch.create(train_dataset, valid_dataset, device='cuda', bs=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center> <h1> MPNN + 3-seq RNN model </h1></center>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = Net(cfg, y_range=[-36.2186, 204.8800])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Init Fastai Learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tCriterion: lmaeo2ceha\n",
      "\n",
      "\n",
      " Load GaussRank mapping\n",
      "\tTraining loss: functools.partial(<function train_criterion at 0x7f4c0d35b950>, criterion='lmaeo2ceha', num_output=1, gaussrank=False, pred_type=True)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#### Init Fastai learner \n",
    "# Training options read from the experiment config.\n",
    "loss_name = cfg['train']['loss_name']\n",
    "num_output = cfg['model']['regression']['num_output']\n",
    "predict_type = cfg['model']['regression']['predict_type']\n",
    "gaussrank = cfg['dataset']['gaussrank']\n",
    "print('\\tCriterion: %s\\n'%(loss_name))\n",
    "\n",
    "### Get GaussRank mapping \n",
    "print('\\n Load GaussRank mapping')\n",
    "data_dir = DATA_DIR + '/rnn_parquet'\n",
    "normalize = cfg['dataset']['normalize']\n",
    "files = glob.glob(data_dir+'/fold_%s/'%fold+'*.csv')\n",
    "# One slot per mapping file, indexed by the order parsed from the file name.\n",
    "mapping_frames = ['']*8\n",
    "coupling_order = ['']*8\n",
    "\n",
    "for file in files:\n",
    "    name = os.path.basename(file)\n",
    "    # The third '_'-separated token of the file name is the coupling type.\n",
    "    type_ = name.split('_')[2]\n",
    "    # Drop the extension with splitext: str.strip('.csv') removes any of the\n",
    "    # characters '.', 'c', 's', 'v' from both ends, not the '.csv' suffix.\n",
    "    order = int(os.path.splitext(name)[0].split('_')[-1])\n",
    "    coupling_order[order] = type_\n",
    "    mapping_frames[order] = pd.read_csv(file)\n",
    "\n",
    "grm = GaussRankMap(mapping_frames, coupling_order)\n",
    "\n",
    "optal = partial(RAdam)\n",
    "\n",
    "# LMAE is attached through callback_fns and receives the GaussRank mapping\n",
    "# together with the normalisation options taken from the config.\n",
    "learn = Learner(data,\n",
    "                net.cuda(),\n",
    "                metrics=None,\n",
    "                opt_func=optal,\n",
    "                callback_fns=partial(LMAE,\n",
    "                                     grm=grm,\n",
    "                                     predict_type=predict_type,\n",
    "                                     normalize_coupling=normalize,\n",
    "                                     coupling_rank=gaussrank))\n",
    "\n",
    "# Replace the learner's loss with train_criterion bound to the config values.\n",
    "learn.loss_func = partial(train_criterion, \n",
    "                          criterion=loss_name,\n",
    "                          num_output=num_output,\n",
    "                          gaussrank=gaussrank,\n",
    "                          pred_type=predict_type) \n",
    "\n",
    "print('\\tTraining loss: %s\\n'%(learn.loss_func))\n",
    "\n",
    "#### fit one cycle \n",
    "epochs = cfg['train']['epochs']\n",
    "max_lr = cfg['train']['max_lr']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit one cycle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>LMAE</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.188517</td>\n",
       "      <td>0.089909</td>\n",
       "      <td>0.008658</td>\n",
       "      <td>03:16</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Better model found at epoch 0 with LMAE value: 0.008657947182655334.\n"
     ]
    }
   ],
   "source": [
    "learn.fit_one_cycle(1,\n",
    "                    0.005, \n",
    "                    callbacks=[SaveModelCallback(learn,\n",
    "                                                 every='improvement',\n",
    "                                                 monitor='LMAE', \n",
    "                                                 name=cfg['train']['model_name']+'_fold_%s'%fold,\n",
    "                                                 mode='min')])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1> <center> Build predictions </center></h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "compute the validation predictions \n",
      "  1164554/ 1164554     1.00   0 hr 00 min\n",
      "\n",
      "predict\n",
      "build preds frame\n",
      "Compute lmae per type\n"
     ]
    }
   ],
   "source": [
    "valid_dataset = TensorBatchDataset(mol_valid, \n",
    "                                tensors=[nodes_matrix, edges_matrix, coupling_matrix,\n",
    "                                        num_nodes_tensor, num_edges_tensor, num_coupling_tensor], \n",
    "                                batch_size=batch_size,\n",
    "                                collate_fn=tensor_collate_rnn,\n",
    "                                COUPLING_MAX=COUPLING_MAX,\n",
    "                                mode='test',\n",
    "                                csv='train')\n",
    "\n",
    "valid_loader = BatchDataLoader(valid_dataset, \n",
    "                               shuffle=False, \n",
    "                               pin_memory=False, \n",
    "                               drop_last=False, \n",
    "                               device='cuda')\n",
    "\n",
    "valid_dataset.get_total_samples()\n",
    "print('compute the validation predictions ')    \n",
    "valid_loss, reverse_frame, contributions, molecule_representation = do_test(learn.model,\n",
    "                                                                       valid_loader,\n",
    "                                                                       valid_dataset.total_samples,\n",
    "                                                                       1,\n",
    "                                                                       predict_type,\n",
    "                                                                       grm,\n",
    "                                                                       normalize=normalize,\n",
    "                                                                       gaussrank=gaussrank)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "|------------------------------------ VALID ------------------------------------------------|\n",
      "\n",
      "| 1JHC,   2JHC,   3JHC,   1JHN,   2JHN,   3JHN,   2JHH,   3JHH  |  loss  mae log_mae | fold |\n",
      "\n",
      "|-------------------------------------------------------------------------------------------|\n",
      "\n",
      "|+0.417, -0.314, +0.214, +0.593, -0.243, -0.596, -0.245, +0.242 | +0.085  1.09 +0.01 |  1   |\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print('\\n')\n",
    "print('|------------------------------------ VALID ------------------------------------------------|\\n')\n",
    "print('| 1JHC,   2JHC,   3JHC,   1JHN,   2JHN,   3JHN,   2JHH,   3JHH  |  loss  mae log_mae | fold |\\n')\n",
    "print('|-------------------------------------------------------------------------------------------|\\n')\n",
    "print('|%+0.3f, %+0.3f, %+0.3f, %+0.3f, %+0.3f, %+0.3f, %+0.3f, %+0.3f | %+5.3f %5.2f %+0.2f |  %s   |\\n' %(*valid_loss[:11], fold))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = cfg['dataset']['input_path']\n",
    "batch_size = cfg['train']['batch_size']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load test data\n",
      "\n",
      " Compute predictions for test data at fold 1\n",
      "\n",
      "  2505542/ 2505542     1.00   0 hr 00 min\n",
      "\n",
      "predict\n",
      "compute the reverse frame\n",
      "Compute lmae per type\n"
     ]
    }
   ],
   "source": [
    "print('load test data')\n",
    "torch.cuda.empty_cache()\n",
    "test = gd.read_parquet(DATA_DIR +'/rnn_parquet/test.parquet')\n",
    "# Cast the per-molecule counts to int64, as is done for train/validation.\n",
    "num_nodes_tensor = from_dlpack(test['num_nodes'].to_dlpack()).long()\n",
    "num_edges_tensor = from_dlpack(test['num_edge'].to_dlpack()).long()\n",
    "num_coupling_tensor = from_dlpack(test['num_coupling'].to_dlpack()).long()\n",
    "# Feature columns are selected by name prefix and converted to float32.\n",
    "# Each matrix is converted exactly once.\n",
    "node_cols = [i for i in test.columns if re.compile(\"^node_[0-9]+\").findall(i)]\n",
    "nodes_matrix = from_dlpack(test[node_cols].to_dlpack()).type(torch.float32)\n",
    "edge_cols = [i for i in test.columns if re.compile(\"^edge_[0-9]+\").findall(i)]\n",
    "edges_matrix = from_dlpack(test[edge_cols].to_dlpack()).type(torch.float32)\n",
    "coupling_cols = [i for i in test.columns if re.compile(\"^coupling_[0-9]+\").findall(i)]\n",
    "coupling_matrix = from_dlpack(test[coupling_cols].to_dlpack()).type(torch.float32)\n",
    "\n",
    "mol_test  = test.molecule_name.unique().to_pandas().values\n",
    "del test\n",
    "\n",
    "test_dataset = TensorBatchDataset(mol_test, \n",
    "                                tensors=[nodes_matrix, edges_matrix, coupling_matrix,\n",
    "                                         num_nodes_tensor, num_edges_tensor, num_coupling_tensor], \n",
    "                                batch_size=batch_size,\n",
    "                                collate_fn=tensor_collate_rnn,\n",
    "                                COUPLING_MAX=COUPLING_MAX,\n",
    "                                mode='test',\n",
    "                                csv='test')\n",
    "\n",
    "test_loader = BatchDataLoader(test_dataset, \n",
    "                               shuffle=False, \n",
    "                               pin_memory=False, \n",
    "                               drop_last=False, \n",
    "                               device='cuda')\n",
    "\n",
    "print('\\n Compute predictions for test data at fold %s\\n' %fold)\n",
    "test_loss, preds_fold_test, contributions, molecule_representation = do_test(\n",
    "    learn.model,\n",
    "    test_loader,\n",
    "    cfg['train']['test_shape'],\n",
    "    1,\n",
    "    predict_type,\n",
    "    grm,\n",
    "    # NOTE(review): the validation call passes normalize=normalize;\n",
    "    # confirm that forcing False is intended for the test set.\n",
    "    normalize=False,\n",
    "    gaussrank=gaussrank)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save validation and test frames "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last entry of valid_loss is used as the score tag in the file names.\n",
    "val_loss = valid_loss[-1]\n",
    "print('\\n Save Validation frame' )\n",
    "# Read the output directory from the config (as the test-frame cell does)\n",
    "# instead of hard-coding a user-specific absolute path.\n",
    "out_dir = cfg['dataset']['output_path']\n",
    "# Timestamp formatted as YYYY-MM-DD-HH-MM-SS (sub-second part dropped).\n",
    "clock = \"{}\".format(datetime.now()).replace(' ','-').replace(':','-').split('.')[0]\n",
    "output_name = out_dir + '/submit/scalar_output/cv_%s_%s_%.4f_fold_%s.csv.gz'%(clock, loss_name, val_loss, fold)\n",
    "reverse_frame.to_csv(output_name, index=False,compression='gzip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save test predictions \n",
    "print('\\n Save Test frame' )\n",
    "out_dir =   cfg['dataset']['output_path']\n",
    "clock = \"{}\".format(datetime.now()).replace(' ','-').replace(':','-').split('.')[0]\n",
    "output_name = out_dir + '/submit/scalar_output/sub_%s_%s_%.4f_fold_%s.csv.gz'%(clock, loss_name, val_loss, fold)\n",
    "preds_fold_test.to_csv(output_name, index=False,compression='gzip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
